import torch,math


class Conv2d_mat_BUG_adaptive(torch.nn.Module):
    """2D convolution whose flattened weight is stored in factored form U @ S @ V^T.

    The weight matrix of shape (out_channels, in_channels * kh * kw) is never
    materialized: `forward` unfolds the input into patches and multiplies by
    V, S^T and U^T in turn. Only the leading `rank` columns of U and V (and the
    leading `rank` x `rank` block of S) are active; `step`, `s_step_preprocess`
    and `Truncate` update the bases and adapt `rank` between 1 and
    `maximal_rank`.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, padding=0, stride=1,
                 bias=True, start_rank_percent=0.1, groups=None, device=None, tau=0.1) -> None:
        """
        Initializer for the convolutional low rank layer (filterwise), extension of the classical Pytorch's convolutional layer.
        INPUTS:
        in_channels: number of input channels (Pytorch's standard)
        out_channels: number of output channels (Pytorch's standard)
        kernel_size : kernel size of the convolutional filter, int or pair (Pytorch's standard)
        dilation : dilation of the convolution, int or pair (Pytorch's standard)
        padding : padding of the convolution, int or pair (Pytorch's standard)
        stride : stride of the filter, int or pair (Pytorch's standard)
        bias  : flag variable for the bias to be included; when False, `self.bias` is None
        start_rank_percent : starting rank. An int is taken as the rank itself, any other
                number as a fraction of min(out_channels, in_channels*kh*kw). The result is
                clamped to [1, maximal_rank].
        groups : accepted for signature compatibility; not used by this layer
        device : device on which the parameters are created (None = Pytorch default)
        tau : relative tolerance used by `Truncate` to choose the new rank
        """
        super().__init__()

        self.kernel_size = self._pair(kernel_size)
        # Use the normalized pair so tuple kernel sizes work as well as ints.
        self.kernel_size_number = self.kernel_size[0] * self.kernel_size[1]
        self.out_channels = out_channels
        self.dilation = self._pair(dilation)
        self.padding = self._pair(padding)
        self.stride = self._pair(stride)
        self.in_channels = in_channels
        self.device = device
        self.tau = tau
        self.augmented_basis = False

        smallest_dim = min(self.out_channels, self.in_channels * self.kernel_size_number)
        # At least one column, otherwise every factor would be an empty tensor.
        self.maximal_rank = max(1, smallest_dim // 2)
        if isinstance(start_rank_percent, int):
            start_rank = start_rank_percent
        else:
            start_rank = int(smallest_dim * start_rank_percent)
        # Keep the active rank inside the storage that is actually allocated.
        self.rank = max(1, min(start_rank, self.maximal_rank))

        self.U = torch.nn.Parameter(torch.randn(size=(self.out_channels, self.maximal_rank), device=device), requires_grad=True)
        self.V = torch.nn.Parameter(torch.randn(size=(self.kernel_size_number * self.in_channels, self.maximal_rank), device=device), requires_grad=True)
        self.S = torch.nn.Parameter(torch.randn(size=(self.maximal_rank, self.maximal_rank), device=device), requires_grad=True)
        # U1 / V1 hold the candidate bases computed by step("K") until
        # s_step_preprocess copies them into U / V.
        self.U1 = torch.nn.Parameter(torch.randn(size=(self.out_channels, self.maximal_rank), device=device))
        self.V1 = torch.nn.Parameter(torch.randn(size=(self.in_channels * self.kernel_size_number, self.maximal_rank), device=device))
        self.Us = [self.U, self.V]
        self.C = self.S

        if bias:
            self.bias = torch.nn.Parameter(torch.empty(self.out_channels, device=device))
        else:
            # Registering None keeps `self.bias` defined while adding no bias term.
            self.register_parameter("bias", None)

        self.reset_parameters()

    @staticmethod
    def _pair(value):
        """Return `value` as a 2-tuple: sequences are converted, scalars are duplicated."""
        if isinstance(value, (tuple, list)):
            return tuple(value)
        return (value, value)

    @staticmethod
    def _qr_basis(old_basis, grad, adaptive):
        """Return the Q factor of [old_basis, grad] (adaptive) or of grad alone."""
        stacked = torch.cat((old_basis, grad), 1) if adaptive else grad
        q, _ = torch.linalg.qr(stacked, 'reduced')
        return q

    @torch.no_grad()
    def reset_parameters(self):
        """Initialize S, U, V (and the bias, if present) and orthonormalize U and V."""
        torch.nn.init.kaiming_uniform_(self.S, a=math.sqrt(5))
        torch.nn.init.kaiming_uniform_(self.U, a=math.sqrt(5))
        torch.nn.init.kaiming_uniform_(self.V, a=math.sqrt(5))

        # Orthonormalize bases
        q_u, _ = torch.linalg.qr(self.U.data, 'reduced')
        q_v, _ = torch.linalg.qr(self.V.data, 'reduced')
        self.U.data.copy_(q_u)
        self.V.data.copy_(q_v)

        if self.bias is not None:
            # The bias was allocated with torch.empty, so it must be filled here.
            # The bound is 1/sqrt(fan_in) of the equivalent (out, in*kh*kw) weight.
            fan_in = self.in_channels * self.kernel_size_number
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            torch.nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        """
        Forward phase for the convolutional layer.

        The input (batch, in_channels, H, W) is unfolded into patches, multiplied by
        V, S^T and U^T restricted to the active rank, optionally shifted by the bias,
        and reshaped to (batch, out_channels, out_h, out_w). When
        `self.augmented_basis` is True the active rank is min(maximal_rank, 2*rank).
        """
        batch_size = input.shape[0]

        if self.augmented_basis:
            r1 = min(self.maximal_rank, 2 * self.rank)
        else:
            r1 = self.rank

        U_hat, S_hat, V_hat = self.U[:, :r1], self.S[:r1, :r1], self.V[:, :r1]
        # dilation must be forwarded to unfold so the number of patches agrees
        # with out_h * out_w computed below.
        inp_unf = torch.nn.functional.unfold(
            input, self.kernel_size, dilation=self.dilation,
            padding=self.padding, stride=self.stride)
        if self.device is not None:
            inp_unf = inp_unf.to(self.device)

        out_h = (input.shape[2] + 2 * self.padding[0]
                 - self.dilation[0] * (self.kernel_size[0] - 1) - 1) // self.stride[0] + 1
        out_w = (input.shape[3] + 2 * self.padding[1]
                 - self.dilation[1] * (self.kernel_size[1] - 1) - 1) // self.stride[1] + 1

        out_unf = inp_unf.transpose(1, 2).matmul(V_hat)
        out_unf = out_unf.matmul(S_hat.t())
        out_unf = out_unf.matmul(U_hat.t())
        if self.bias is not None:
            out_unf = out_unf + self.bias
        out_unf = out_unf.transpose(1, 2)

        return out_unf.view(batch_size, self.out_channels, out_h, out_w)

    @torch.no_grad()
    def s_step_preprocess(self, adaptive=True):
        """Project S onto the candidate bases U1 / V1 and make them the current bases.

        With `adaptive=True` the candidate bases have min(2*rank, maximal_rank)
        columns, otherwise `rank` columns.
        """
        r = self.rank
        r1 = min(2 * r, self.maximal_rank) if adaptive else r

        M = torch.matmul(self.U1[:, :r1].T, self.U[:, :r])
        N = torch.matmul(self.V[:, :r].T, self.V1[:, :r1])
        self.S.data[:r1, :r1] = M @ self.S[:r, :r] @ N
        # update basis
        self.U.data[:, :r1] = self.U1[:, :r1]
        self.V.data[:, :r1] = self.V1[:, :r1]

    @torch.no_grad()
    def step(self, dlrt_step, adaptive=True, lr=0.05):
        """Run one phase of the basis update.

        dlrt_step : "K" builds the candidate bases U1 and V1 from the current bases
                and the gradients of U and V (both factors are handled in this single
                phase); "S" truncates the rank via `Truncate`. Any other value is a
                no-op.
        adaptive : if True, the old basis is concatenated with the gradient before
                the QR factorization; otherwise only the gradient is orthonormalized.
        lr : accepted for signature compatibility; not used by this method.

        Raises RuntimeError if "K" is requested before gradients of U and V exist.
        """
        r = self.rank
        r1 = min(2 * r, self.maximal_rank) if adaptive else r

        if dlrt_step == "K":
            if self.U.grad is None or self.V.grad is None:
                raise RuntimeError(
                    "step('K') needs gradients of U and V; run backward() first")
            new_U = self._qr_basis(self.U[:, :r], self.U.grad[:, :r], adaptive)
            self.U1.data[:, :r1] = new_U[:, :r1]
            new_V = self._qr_basis(self.V[:, :r], self.V.grad[:, :r], adaptive)
            self.V1.data[:, :r1] = new_V[:, :r1]
        elif dlrt_step == "S":
            self.Truncate()

    @torch.no_grad()
    def Truncate(self):
        """Truncate the augmented factorization and update `self.rank`.

        The leading block of S of size min(2*rank, maximal_rank) is decomposed by SVD.
        The new rank is the first index j at which the norm of the singular values
        from j on is below tau * ||d||, kept to at least 2 where that many singular
        values exist. S becomes diagonal on the new block and U, V are rotated by the
        singular vectors.
        """
        r0 = self.rank
        r_aug = min(2 * r0, self.maximal_rank)
        # torch.linalg.svd returns (P, d, Q) with S = P @ diag(d) @ Q.
        P, d, Q = torch.linalg.svd(self.S[:r_aug, :r_aug])

        tol = self.tau * torch.linalg.norm(d)
        r1 = r_aug
        for j in range(r_aug):
            if torch.linalg.norm(d[j:r_aug]) < tol:
                r1 = j
                break
        # Lower bound of 2, but never more than the singular values available.
        r1 = min(max(r1, 2), r_aug)

        # update s
        self.S.data[:r1, :r1] = torch.diag(d[:r1])

        # update u and v
        self.U.data[:, :r1] = torch.matmul(self.U[:, :r_aug], P[:, :r1])
        self.V.data[:, :r1] = torch.matmul(self.V[:, :r_aug], Q.T[:, :r1])
        self.rank = int(r1)

    @torch.no_grad()
    def activate_grad_step(self, step="S"):
        """Select which factors require gradients: U and V for "K", S for "S"."""
        if step == "K":
            self.U.requires_grad = True
            self.V.requires_grad = True
            self.S.requires_grad = False
        if step == "S":
            self.U.requires_grad = False
            self.V.requires_grad = False
            self.S.requires_grad = True

    @torch.no_grad()
    def copy_U(self):
        """Copy the active columns of U into U1."""
        self.U1.data[:, :self.rank] = self.U.data[:, :self.rank].clone()

    @torch.no_grad()
    def set_grad_zero(self):
        """Zero the gradients of U and V in place; missing gradients are skipped."""
        for factor in (self.U, self.V):
            if factor.grad is not None:
                factor.grad.zero_()